import torch

# 创建两个示例张量
skip_connection = torch.randn(1, 3, 224, 224)  # 形状为 (batch_size, channels1, height, width)
x = torch.randn(1, 6, 224, 224)  # 形状为 (batch_size, channels2, height, width)

# 在通道维度上拼接两个张量
result = torch.cat((skip_connection, x), dim=1)

print("skip_connection 的形状:", skip_connection.shape)
print("x 的形状:", x.shape)
print("拼接后结果的形状:", result.shape)
# 输出结果：
# skip_connection 的形状: torch.Size([1, 3, 224, 224])
# x 的形状: torch.Size([1, 6, 224, 224])
# 拼接后结果的形状: torch.Size([1, 9, 224, 224])